13. Training in PyTorch
Training in PyTorch
Once you've loaded a training dataset, next your job will be to define a CNN and train it to classify that set of images.
Loss and Optimizer
To train a model, you'll need to define how it trains by selecting a loss function and optimizer. These functions decide how the model updates its parameters as it trains and can affect how quickly the model converges, as well.
Learn more about loss functions and optimizers in the online documentation.
For a classification problem like this, one typically uses cross entropy loss, which can be defined in code like: criterion = nn.CrossEntropyLoss()
. PyTorch also includes some standard stochastic optimizers like stochastic gradient descent and Adam. You're encouraged to try different optimizers and see how your model responds to these choices as it trains.
Clasisification vs. Regression
The loss function you should choose depends on the kind of CNN you are trying to create; cross entropy is generally good for classification tasks, but you might choose a different loss function for, say, a regression problem that tried to predict (x,y) locations for the center or edges of clothing items instead of class scores.
Training the Network
Typically, we train any network for a number of epochs or cycles through the training dataset
Here are the steps that a training function performs as it iterates over the training dataset:
- Prepares all input images and label data for training
- Passes the input through the network (forward pass)
- Computes the loss (how far is the predicted classes are from the correct labels)
- Propagates gradients back into the network’s parameters (backward pass)
- Updates the weights (parameter update)
It repeats this process until the average loss has sufficiently decreased.
And in the next notebook, you'll see how to train and test a CNN for clothing classification, in detail.